// This is a rough example of how to use RustQuant's `autodiff` module
// to calibrate a model to market data.
// The model is the Black-Scholes model, and the market data is
// the strikes and prices of call options on AAPL.
// We are simply calibrating the volatility parameter of the model.
//
// Run: cargo run --release --example calibration

use std::time::Instant;
use RustQuant::autodiff::*;
use RustQuant::math::optimization::gradient_descent::GradientDescent;

/// Entry point of the example: runs the volatility calibration and lets
/// `calibrate` print its results.
fn main() {
    calibrate();
}

/// Calibrates the Black-Scholes volatility to the hard-coded AAPL call data.
///
/// Minimises [`mse`] with gradient descent starting from an initial
/// volatility guess, then prints:
///
/// * the mean squared error evaluated at the calibrated volatility,
/// * the calibrated volatility itself,
/// * the wall-clock time spent inside the optimiser.
///
/// # Panics
///
/// Panics if the optimiser returns an empty minimizer.
#[inline]
pub(crate) fn calibrate() {
    // Single starting point for the volatility search.
    let initial_guess = [0.1];

    // NOTE(review): arguments are presumably (learning rate, max iterations,
    // tolerance) — confirm against `GradientDescent::new`.
    let gd = GradientDescent::new(0.001, 1000, None);

    // Only the optimisation itself is timed.
    let start = Instant::now();
    let result = gd.optimize(mse, &initial_guess, true);
    let duration = start.elapsed();

    let vol = *result
        .minimizer
        .first()
        .expect("gradient descent returned an empty minimizer");

    // Report the error at the calibrated volatility rather than at the
    // initial guess, so the printed MSE matches the printed Vol.
    let graph = Graph::new();
    println!("MSE  = \t {}", mse(&graph.vars(&[vol])));
    println!("Vol  = \t {:?}", vol);
    println!("Time = \t {:?}", duration);
}

/// Objective function for the calibration: mean squared error between
/// Black-Scholes call prices and the market prices in `DATA`.
///
/// Only `v[0]` is read; it is passed to `black_scholes` as the volatility.
/// The result is a `Variable` with the same lifetime as the inputs, so it can
/// be differentiated with respect to `v[0]`.
fn mse<'v>(v: &[Variable<'v>]) -> Variable<'v> {
    // HARD CODED PARAMETERS (as at July 21, 2023), passed to `black_scholes`
    // as S (spot), r (rate) and d (dividend yield) respectively.
    let (s, r, d) = (192.64, 0.0525, 0.005);

    // HARD CODED DATA FROM YAHOO FINANCE (see `DATA` at the bottom of the file).
    // `i` selects which expiry to calibrate against; 0 is the entry labelled
    // "Jul 21 2023".                                  <-- CHANGE THIS TO CHANGE THE DATE
    let i = 0;
    let (days, strikes, prices) = DATA[i];

    // Convert the day count into a year fraction.
    let t = days as f64 / 247.;

    // Sum of squared errors (Model - Market)^2 over every strike/price pair.
    let se = strikes
        .iter()
        .zip(prices)
        .map(|(&strike, &price)| {
            let model = black_scholes(s, strike, t, r, v[0], d, TypeFlag::CALL);
            let error = model - price;
            error.powf(2.)
        })
        .sum::<Variable>();

    // Average over the number of strikes to obtain the MSE.
    se / (strikes.len() as f64)
}

/// Standard normal cumulative distribution function for autodiff variables,
/// computed via the complementary error function: `0.5 * erfc(-x / sqrt(2))`.
#[allow(non_snake_case)]
#[inline]
fn N(x: Variable<'_>) -> Variable<'_> {
    let scaled = -x / core::f64::consts::SQRT_2;
    0.5 * scaled.erfc()
}

/// Black-Scholes price of a European option with a continuous dividend yield,
/// expressed in terms of an autodiff volatility variable.
///
/// # Arguments
///
/// * `S` - underlying price.
/// * `K` - strike price.
/// * `T` - time to expiry as a year fraction (the caller passes days / 247).
/// * `r` - interest rate used to discount the strike.
/// * `v` - volatility, as a `Variable` so the price can be differentiated.
/// * `d` - dividend yield used to discount the underlying.
/// * `type_flag` - whether to price a call or a put.
///
/// # Returns
///
/// The option price as a `Variable` tied to the lifetime of `v`.
#[allow(non_snake_case)]
#[inline]
fn black_scholes(
    S: f64,
    K: f64,
    T: f64,
    r: f64,
    v: Variable<'_>,
    d: f64,
    type_flag: TypeFlag,
) -> Variable<'_> {
    let sqrt_T = T.sqrt();
    let log_moneyness = (S / K).ln();
    let carry = (r - d + v * v / 2.0) * T;

    let d1 = (log_moneyness + carry) / (v * sqrt_T);
    let d2 = d1 - v * sqrt_T;

    // Discounted underlying and discounted strike, shared by both payoffs.
    let spot_pv = S * (-d * T).exp();
    let strike_pv = K * (-r * T).exp();

    match type_flag {
        TypeFlag::CALL => spot_pv * N(d1) - strike_pv * N(d2),
        TypeFlag::PUT => -spot_pv * N(-d1) + strike_pv * N(-d2),
    }
}

/// Selects which payoff `black_scholes` prices.
///
/// The common traits are derived so the flag can be debug-printed, compared
/// and passed by value more than once.
// `PUT` is never constructed in this example, hence `dead_code`; the
// upper-case variant names are kept so existing call sites stay valid.
#[allow(dead_code)]
#[allow(clippy::upper_case_acronyms)]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum TypeFlag {
    /// Call option.
    CALL,
    /// Put option.
    PUT,
}

/// Option data from Yahoo Finance for AAPL call options, one entry per expiry.
///
/// Each entry is a tuple `(DAYS_TO_EXPIRY, STRIKES, PRICES)`. `mse` zips
/// `STRIKES` with `PRICES`, so `STRIKES[k]` is matched with `PRICES[k]`.
///
/// NOTE(review): `mse` divides the day count by 247, so these look like
/// trading-day counts — confirm.
const DATA: &[(usize, &[f64], &[f64]); 11] = &[
    // Jul 21 2023 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    (
        1,
        &[
            50., 55., 60., 65., 70., 75., 80., 85., 90., 95., 100., 105., 110., 115., 120., 125.,
            130., 135., 140., 145., 150., 155., 157.5, 160., 162.5, 165., 167.5, 170., 172.5, 175.,
            177.5, 180., 182.5, 185., 187.5, 190., 192.5, 195., 197.5, 200., 202.5, 205., 207.5,
            210., 212.5, 215., 217.5, 220., 225., 230., 235., 240., 245., 250., 255., 260., 270.,
            280.,
        ],
        &[
            144.85, 139.85, 134.3, 123.1, 124.15, 118.25, 113.25, 108.25, 103.25, 98.45, 94.4,
            88.02, 83.59, 78.56, 74.3, 69.22, 63.5, 58.65, 53.2, 48.4, 43.25, 38.78, 36.3, 33.25,
            30.83, 28., 25.58, 23., 20.94, 18.29, 15.96, 12.85, 11.11, 7.95, 5.45, 2.95, 0.65,
            0.04, 0.01, 0.01, 0.01, 0.01, 0.01, 0.01, 0.01, 0.01, 0.01, 0.01, 0.01, 0.01, 0.01,
            0.01, 0.01, 0.01, 0.01, 0.01, 0.01, 0.01,
        ],
    ),
    // Jul 28 2023 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    (
        6,
        &[
            50., 60., 65., 70., 75., 80., 85., 90., 100., 105., 110., 115., 120., 125., 130., 135.,
            140., 145., 150., 155., 160., 162.5, 165., 167.5, 170., 172.5, 175., 177.5, 180.,
            182.5, 185., 187.5, 190., 192.5, 195., 197.5, 200., 202.5, 205., 207.5, 210., 212.5,
            215., 217.5, 220., 225., 230., 235., 240., 245., 250., 255., 260.,
        ],
        &[
            143.0, 130.2, 122.95, 120.0, 114.7, 108.1, 104.9, 100.35, 94.25, 88.03, 80.35, 72.75,
            67.45, 69.34, 60.8, 58.62, 53.27, 50.2, 43.3, 38.35, 33.17, 30.85, 28.53, 27.3, 23.61,
            21.15, 18.68, 16.19, 13.62, 11.42, 8.76, 6.45, 4.45, 2.52, 1.33, 0.64, 0.28, 0.14,
            0.08, 0.06, 0.04, 0.03, 0.02, 0.02, 0.02, 0.01, 0.01, 0.01, 0.02, 0.01, 0.01, 0.01,
            0.01,
        ],
    ),
    // Aug 04 2023 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    (
        11,
        &[
            50., 60., 65., 70., 75., 80., 85., 90., 95., 100., 105., 110., 115., 120., 125., 130.,
            135., 140., 145., 150., 155., 160., 162.5, 165., 167.5, 170., 172.5, 175., 177.5, 180.,
            182.5, 185., 187.5, 190., 192.5, 195., 197.5, 200., 202.5, 205., 207.5, 210., 212.5,
            215., 217.5, 220., 225., 230., 235., 240., 245., 250., 255., 260.,
        ],
        &[
            143.25, 133.3, 123.95, 119.2, 114., 110.05, 105.35, 100.4, 95.4, 95.17, 85.15, 80.66,
            78.79, 70.05, 65.7, 65.2, 59.38, 53.45, 48.99, 43.32, 39.22, 34.2, 31.3, 29.14, 26.1,
            24.38, 21.47, 18.99, 17.17, 14.4, 12.5, 9.95, 8.17, 6.46, 4.72, 3.45, 2.56, 1.75, 1.19,
            0.85, 0.6, 0.4, 0.28, 0.19, 0.17, 0.12, 0.07, 0.06, 0.05, 0.03, 0.02, 0.03, 0.02, 0.01,
        ],
    ),
    // Aug 11 2023 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    (
        16,
        &[
            50., 80., 95., 110., 130., 135., 140., 145., 150., 155., 160., 165., 170., 175., 180.,
            185., 190., 195., 200., 205., 210., 215., 220., 225., 230., 235., 240., 245., 250.,
            255., 260.,
        ],
        &[
            138.21, 114.55, 99.95, 83.09, 63.34, 58.42, 55.59, 51.02, 45.42, 38.95, 34.16, 29.29,
            24.25, 19.29, 14.72, 10.48, 7.2, 3.9, 2.2, 1., 0.5, 0.24, 0.16, 0.1, 0.09, 0.05, 0.04,
            0.03, 0.03, 0.02, 0.02,
        ],
    ),
    // Aug 18 2023 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    (
        20,
        &[
            50., 55., 60., 65., 70., 75., 80., 85., 90., 95., 100., 105., 110., 115., 120., 125.,
            130., 135., 140., 145., 150., 155., 160., 165., 170., 175., 180., 185., 190., 195.,
            200., 205., 210., 215., 220., 225., 230., 235., 240., 245., 250., 255., 260., 265.,
            270., 275., 280., 285.,
        ],
        &[
            124.8, 132.9, 135.75, 120.53, 118.3, 100.11, 111.15, 106.47, 101.95, 96.45, 93.85,
            86.1, 83.86, 76.56, 73.4, 70.2, 65.71, 58.18, 54.07, 48.45, 44.1, 39.23, 34., 29.13,
            24.15, 19.51, 15.2, 11.07, 6.95, 4.25, 2.32, 1.19, 0.59, 0.29, 0.18, 0.13, 0.08, 0.07,
            0.05, 0.04, 0.04, 0.02, 0.03, 0.03, 0.02, 0.02, 0.02, 0.01,
        ],
    ),
    // Aug 25 2023 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    (
        25,
        &[
            80., 110., 115., 120., 125., 130., 135., 140., 145., 150., 155., 160., 165., 170.,
            175., 180., 185., 190., 195., 200., 205., 210., 215., 220., 225., 230., 235., 240.,
            245., 250., 255., 260.,
        ],
        &[
            115., 87., 76.35, 73.78, 63.05, 60.75, 58.45, 53.51, 49.08, 43.98, 39.09, 34.1, 29.11,
            24.5, 19.66, 15.45, 11.3, 7.8, 4.75, 2.8, 1.43, 0.75, 0.39, 0.24, 0.15, 0.11, 0.08,
            0.07, 0.06, 0.05, 0.03, 0.03,
        ],
    ),
    // Sep 01 2023 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    (
        30,
        &[
            100., 115., 125., 140., 145., 150., 155., 160., 165., 170., 175., 180., 185., 190.,
            195., 200., 205., 210., 215., 220., 225., 230., 235., 240., 245., 250., 260.,
        ],
        &[
            96.05, 76., 69.01, 54.59, 49.18, 44.1, 39.6, 34.5, 29.46, 25.69, 21.04, 15.6, 11.4,
            8.39, 5.35, 3.15, 1.71, 0.92, 0.5, 0.28, 0.19, 0.14, 0.12, 0.08, 0.06, 0.08, 0.03,
        ],
    ),
    // Sep 15 2023 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    (
        47,
        &[
            65., 70., 75., 80., 85., 90., 95., 100., 105., 110., 115., 120., 125., 130., 135.,
            140., 145., 150., 155., 160., 165., 170., 175., 180., 185., 190., 195., 200., 205.,
            210., 220., 230., 240., 250., 260., 270., 280., 290., 300.,
        ],
        &[
            131.05, 125.71, 120.65, 114.8, 110.25, 106.25, 98.65, 95.85, 90.2, 86., 79.66, 74.19,
            70.92, 64.3, 59.15, 53.65, 49.08, 44.85, 39.35, 35., 30.09, 25.57, 21.1, 16.31, 12.77,
            8.75, 5.85, 3.65, 2.15, 1.21, 0.35, 0.16, 0.1, 0.06, 0.04, 0.04, 0.03, 0.02, 0.02,
        ],
    ),
    // Oct 20 2023 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    (
        64,
        &[
            55., 60., 65., 70., 75., 80., 85., 90., 95., 100., 105., 110., 115., 120., 125., 130.,
            135., 140., 145., 150., 155., 160., 165., 170., 175., 180., 185., 190., 195., 200.,
            205., 210., 215., 220., 225., 230., 235., 240., 245., 250., 255., 260., 265., 270.,
            275., 280., 285., 290., 295.,
        ],
        &[
            137.03, 133.23, 87., 104.91, 115.55, 113.75, 105.37, 101.71, 99.3, 95.85, 88.41, 84.39,
            79.89, 75.11, 71.55, 67.07, 62.28, 55.65, 52.78, 45.9, 40.4, 36.22, 31.44, 26.82,
            22.17, 18.51, 14.25, 11.07, 7.85, 5.35, 3.6, 2.29, 1.41, 0.8, 0.5, 0.28, 0.2, 0.15,
            0.12, 0.09, 0.09, 0.08, 0.05, 0.05, 0.05, 0.04, 0.05, 0.03, 0.03,
        ],
    ),
    // Nov 17 2023 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    (
        82,
        &[
            50., 55., 60., 65., 70., 75., 80., 85., 90., 95., 100., 105., 110., 115., 120., 125.,
            130., 135., 140., 145., 150., 155., 160., 165., 170., 175., 180., 185., 190., 195.,
            200., 205., 210., 215., 220., 225., 230., 235., 240., 245., 250., 255., 260., 265.,
            270., 275., 280., 285., 290., 300.,
        ],
        &[
            145., 132.76, 126., 120.9, 116., 115.73, 102.43, 110.15, 100.07, 97.9, 92.8, 89.62,
            79.9, 74.07, 73.61, 71.2, 67.52, 56.95, 57.43, 50.9, 49.02, 42.03, 36.83, 33., 28.11,
            24.17, 20.2, 16.15, 13.1, 9.85, 7.45, 5.27, 3.45, 2.46, 1.53, 1.03, 0.63, 0.41, 0.29,
            0.25, 0.14, 0.13, 0.13, 0.12, 0.1, 0.08, 0.04, 0.07, 0.05, 0.05,
        ],
    ),
    // Dec 15 2023 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    (
        101,
        &[
            65., 70., 75., 80., 85., 90., 95., 100., 105., 110., 115., 120., 125., 130., 135.,
            140., 145., 150., 155., 160., 165., 170., 175., 180., 185., 190., 195., 200., 205.,
            210., 215., 220., 225., 230., 235., 240., 245., 250., 255., 260., 265., 270., 275.,
            280., 285., 290., 295.,
        ],
        &[
            129.01, 124.7, 107.85, 112.9, 109.27, 104.46, 94.7, 93.4, 86.42, 79.12, 80.92, 75.8,
            65.2, 65.64, 63.27, 58.65, 52., 47., 43.45, 37.8, 33.95, 29.4, 25.1, 21.16, 17.9,
            14.51, 11.49, 8.5, 6.45, 4.75, 3.15, 2.16, 1.46, 0.95, 0.66, 0.45, 0.33, 0.23, 0.19,
            0.14, 0.14, 0.11, 0.09, 0.07, 0.09, 0.08, 0.07,
        ],
    ),
];

// let strikes: &[f64] = &[
//     70.0, 100.0, 110.0, 115.0, 120.0, 125.0, 130.0, 135.0, 140.0, 145.0, 148.0, 149.0, 150.0,
//     152.5, 155.0, 157.5, 160.0, 162.5, 165.0, 167.5, 170.0, 172.5, 175.0, 177.5, 180.0, 182.5,
//     185.0, 187.5, 190.0, 192.5, 195.0, 197.5, 200.0, 202.5, 205.0, 207.5, 210.0, 215.0, 220.0,
//     225.0, 230.0, 235.0, 240.0, 245.0, 250.0, 255.0,
// ];
// let prices: &[f64] = &[
//     108.57, 79.80, 75.93, 70.92, 65.89, 60.20, 54.00, 50.73, 45.80, 40.65, 34.75, 33.80, 35.05,
//     33.24, 30.01, 27.73, 24.90, 23.21, 19.90, 17.57, 14.87, 12.40, 10.25, 7.81, 5.27, 3.26,
//     1.69, 0.67, 0.24, 0.08, 0.03, 0.03, 0.02, 0.02, 0.02, 0.01, 0.02, 0.01, 0.02, 0.02, 0.01,
//     0.01, 0.01, 0.01, 0.01, 0.01,
// ];
